import argparse
import json
import os

import numpy as np
from env_utils import find_product
from filter_candidate import run as filter_candidate_run
from calc_reward import calc_reward
import dump_reward
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.stats import linregress

# Command-line interface. All three arguments are required; together they
# locate the JSONL file read below: <data_dir>/<input_dir>/<output_file_name>.
parser = argparse.ArgumentParser(
    description="Bin (Elo, reward) records, print their correlation and plot reward against Elo."
)
parser.add_argument('--data_dir', type=str, required=True,
                    help='root data directory')
parser.add_argument('--input_dir', type=str, required=True,
                    help='sub-directory of data_dir that contains the results file')
parser.add_argument('--output_file_name', type=str, required=True,
                    help='JSONL file inside data_dir/input_dir with one record per line '
                         'holding "Elo" and "reward" fields')
args = parser.parse_args()
output_file_name = args.output_file_name
input_dir = args.input_dir
data_dir = args.data_dir


# candidate_file_name="candidate.jsonl"
# filter_candidate_run(data_dir=data_dir, input_dir=input_dir, output_file_name=candidate_file_name)
# dump_reward.batch_all_buy(data_dir, input_dir)

# def get_max_elo_in_single_tree(content):
#     return max([float(item['Elo']) for item in content])

# # get elos and rewards
# with open(os.path.join(data_dir, input_dir, output_file_name), "w") as fw:

#     with open(os.path.join(data_dir, input_dir, candidate_file_name), "r") as fr:
#         line = fr.readline()
#         while line:
#             line_json = json.loads(line)
#             if line_json != None and len(line_json) != 0:
#                 max_elo = get_max_elo_in_single_tree(line_json[0]['content'])
#                 max_candidate = line_json[0]['content']
#                 for i, item in enumerate(line_json):
#                     idx = item['idx']
#                     elo_reward = get_max_elo_in_single_tree(item['content'])
#                     if elo_reward > max_elo:
#                         max_elo = elo_reward
#                         max_candidate = item['content']
#                 idx, prod, attr = find_product(idx, max_candidate)
#                 reward = calc_reward(idx, prod, attr, in_dir_path=os.path.join(data_dir, input_dir), output_dir_path='')
                
#                 json.dump({
#                     "idx": idx,
#                     "Elo": max_elo,
#                     "reward": reward,
#                 }, fw)
#                 fw.write('\n')
#             line = fr.readline()

def smooth(x, y, num=300, kind='cubic'):
    """Return a densely resampled, interpolated version of the curve (x, y).

    Args:
        x: 1-D array-like of sample positions (lists or NumPy arrays).
        y: 1-D array-like of values at ``x``; same length as ``x``.
        num: Number of evenly spaced points in the returned curve.
        kind: Interpolation kind passed to ``scipy.interpolate.interp1d``.
            The default ``'cubic'`` requires at least 4 points.

    Returns:
        Tuple ``(x_smooth, y_smooth)`` of NumPy arrays of length ``num``;
        ``x_smooth`` spans ``[min(x), max(x)]``.
    """
    # Convert up front so plain Python lists work too: ``.min()``/``.max()``
    # below are ndarray methods.
    x = np.asarray(x)
    y = np.asarray(y)
    f = interp1d(x, y, kind=kind)
    x_smooth = np.linspace(x.min(), x.max(), num=num)
    y_smooth = f(x_smooth)
    return x_smooth, y_smooth
    
def regression(x, y, num=100):
    """Fit a least-squares line to (x, y) and sample it for plotting.

    Args:
        x: 1-D sequence of independent-variable values.
        y: 1-D sequence of dependent-variable values, same length as ``x``.
        num: Number of evenly spaced points at which to evaluate the line.

    Returns:
        Tuple ``(x_fit, y_fit)`` of NumPy arrays of length ``num``; ``x_fit``
        spans ``[min(x), max(x)]`` and ``y_fit = intercept + slope * x_fit``.
    """
    # Only slope and intercept are needed; the other statistics linregress
    # computes (r value, p value, standard errors) are not used here.
    fit = linregress(x, y)

    x_fit = np.linspace(min(x), max(x), num)
    y_fit = fit.intercept + fit.slope * x_fit

    return x_fit, y_fit

def run_plot(x, y, output_path="elo_reward.jpg"):
    """Plot reward against Elo in a fresh figure and save it to an image file.

    The figure contains an interpolated curve through (x, y) (via ``smooth``),
    the raw points as hollow triangles, and a dashed least-squares fit line
    (via ``regression``) shown in the legend.

    Args:
        x: Elo values for the x-axis; ticks are placed at 0, 0.1, ..., 1.
        y: Reward values for the y-axis; ticks are placed at 0.4, 0.44, ..., 0.8.
        output_path: Where to write the image. Defaults to ``elo_reward.jpg``
            in the current working directory.
    """
    # Start from a new figure so earlier pyplot state is not drawn over.
    plt.figure()
    plt.plot(*smooth(x, y), color="#ff870e")
    plt.scatter(x, y, marker='^', color="#ff870e", facecolors='none')
    plt.plot(*regression(x, y), linestyle='--', color='black', label='Fit Line')
    plt.title("Reward - Elo")
    plt.grid(True)
    plt.xticks(np.linspace(0, 1, 11))
    plt.xlabel("Elo Score")
    plt.yticks(np.linspace(0.4, 0.8, 11))
    plt.ylabel("Actual Reward")
    # The fit line carries a label; without a legend that label is never shown.
    plt.legend()
    plt.savefig(output_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close()
    

def divide_in_part(elo_reward_list, num_parts=10):
    """Split (elo, reward) pairs into equal-count bins and average each bin's reward.

    Pairs are binned in the order given (the caller sorts them by Elo first).
    Each bin holds ``len(elo_reward_list) // num_parts`` pairs; the leftover
    ``len(elo_reward_list) % num_parts`` pairs are appended to the last bin.
    The per-bin mean rewards are printed.

    Args:
        elo_reward_list: Sequence of pairs whose element ``[1]`` is the reward.
        num_parts: Number of bins (default 10).

    Returns:
        Tuple ``(elo_normalized, part_pass_rate)``. ``elo_normalized`` is a
        NumPy array with the centres of ``num_parts`` equal-width intervals
        on [0, 1] (``0.05, 0.15, ..., 0.95`` for 10 bins). ``part_pass_rate``
        is a list with the mean reward of each bin.

    Raises:
        ValueError: if ``num_parts`` < 1, or there are fewer pairs than bins
            (some bin would otherwise be empty).
    """
    list_size = len(elo_reward_list)
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")
    if list_size < num_parts:
        raise ValueError(
            f"need at least {num_parts} (elo, reward) pairs to form "
            f"{num_parts} parts, got {list_size}"
        )

    part_size = list_size // num_parts
    # Index where the evenly sized parts end; everything after it is the remainder.
    even_end = part_size * num_parts

    # list(...) so that .extend works even when the input is a tuple or other sequence.
    split_list = [list(elo_reward_list[i:i + part_size]) for i in range(0, even_end, part_size)]
    split_list[-1].extend(elo_reward_list[even_end:])  # remainder goes to the last part

    part_pass_rate = [sum(item[1] for item in part) / len(part) for part in split_list]
    print(part_pass_rate)

    half_width = 0.5 / num_parts
    elo_normalized = np.linspace(half_width, 1 - half_width, num_parts)
    return elo_normalized, part_pass_rate

# plot: read the (Elo, reward) records, bin them by Elo, print the correlation
# between bin position and mean reward, and save the plot.
elo_reward_list = []
with open(os.path.join(data_dir, input_dir, output_file_name), "r", encoding="utf-8") as fr:
    for line in fr:
        # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
        if not line.strip():
            continue
        line_json = json.loads(line)
        elo = line_json['Elo']
        reward = line_json['reward']
        # Records whose Elo or reward is null are ignored.
        if elo is not None and reward is not None:
            elo_reward_list.append((float(elo), float(reward)))

# Sort by Elo so divide_in_part forms bins of increasing Elo.
elo_reward_list.sort(key=lambda pair: pair[0])

elo_normalized, part_pass_rate = divide_in_part(elo_reward_list)

correlation = np.corrcoef(elo_normalized, part_pass_rate)[0, 1]
print(f"correlation: {correlation}")

run_plot(elo_normalized, part_pass_rate)